{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Use one of the following commands to start using PyAlink:\n",
      " - useLocalEnv(parallelism, flinkHome=None, config=None)\n",
      " - useRemoteEnv(host, port, parallelism, flinkHome=None, localIp=\"localhost\", config=None)\n",
      "Call resetEnv() to reset environment and switch to another.\n",
      "\n",
      "JVM listening on 127.0.0.1:64158\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "MLEnv(benv=<pyflink.dataset.execution_environment.ExecutionEnvironment object at 0x120796f60>, btenv=<pyflink.table.table_environment.BatchTableEnvironment object at 0x111e94dd8>, senv=<pyflink.datastream.stream_execution_environment.StreamExecutionEnvironment object at 0x120796d68>, stenv=<pyflink.table.table_environment.StreamTableEnvironment object at 0x1208069b0>)"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from pyalink.alink import *\n",
    "resetEnv()\n",
    "useLocalEnv(1, config=None)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 准备数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "schema = \"age bigint, workclass string, fnlwgt bigint, education string, \\\n",
    "          education_num bigint, marital_status string, occupation string, \\\n",
    "          relationship string, race string, sex string, capital_gain bigint, \\\n",
    "          capital_loss bigint, hours_per_week bigint, native_country string, label string\"\n",
    "\n",
    "adult_batch = CsvSourceBatchOp() \\\n",
    "    .setFilePath(\"https://alink-release.oss-cn-beijing.aliyuncs.com/data-files/adult_train.csv\") \\\n",
    "    .setSchemaStr(schema)\n",
    "\n",
    "adult_stream = CsvSourceStreamOp() \\\n",
    "    .setFilePath(\"https://alink-release.oss-cn-beijing.aliyuncs.com/data-files/adult_test.csv\") \\\n",
    "    .setSchemaStr(schema)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 特征建模"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "categoricalColNames = [\"workclass\", \"education\", \"marital_status\", \"occupation\",\n",
    "                       \"relationship\", \"race\", \"sex\", \"native_country\"]\n",
    "numerialColNames = [\"age\", \"fnlwgt\", \"education_num\", \"capital_gain\",\n",
    "                    \"capital_loss\", \"hours_per_week\"]\n",
    "onehot = OneHotEncoder().setSelectedCols(categoricalColNames) \\\n",
    "        .setOutputCols([\"output\"]).setReservedCols(numerialColNames + [\"label\"])\n",
    "assembler = VectorAssembler().setSelectedCols([\"output\"] + numerialColNames) \\\n",
    "        .setOutputCol(\"vec\").setReservedCols([\"label\"])\n",
    "pipeline = Pipeline().add(onehot).add(assembler)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 训练+预测+评估"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "logistic = LogisticRegression().setVectorCol(\"vec\").setLabelCol(\"label\") \\\n",
    "        .setPredictionCol(\"pred\").setPredictionDetailCol(\"detail\")\n",
    "model = pipeline.add(logistic).fit(adult_batch)\n",
    "\n",
    "predictBatch = model.transform(adult_batch)\n",
    "\n",
    "metrics = EvalBinaryClassBatchOp().setLabelCol(\"label\") \\\n",
    "        .setPredictionDetailCol(\"detail\").linkFrom(predictBatch).collectMetrics()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 输出评估结果"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "AUC: 0.9066240193960077\n",
      "KS: 0.6495268264606959\n",
      "PRC: 0.7662328278289783\n",
      "Precision: 0.733230531996916\n",
      "Recall: 0.6064277515623008\n",
      "F1: 0.6638280050258272\n",
      "ConfusionMatrix: [[4755, 1730], [3086, 22990]]\n",
      "LabelArray: ['>50K', '<=50K']\n",
      "LogLoss: 0.3192012545654014\n",
      "TotalSamples: 32561\n",
      "ActualLabelProportion: [0.2408095574460244, 0.7591904425539756]\n",
      "ActualLabelFrequency: [7841, 24720]\n",
      "Accuracy: 0.8520929946868954\n",
      "Kappa: 0.5701036372627706\n"
     ]
    }
   ],
   "source": [
    "print(\"AUC:\", metrics.getAuc())\n",
    "print(\"KS:\", metrics.getKs())\n",
    "print(\"PRC:\", metrics.getPrc())\n",
    "print(\"Precision:\", metrics.getPrecision())\n",
    "print(\"Recall:\", metrics.getRecall())\n",
    "print(\"F1:\", metrics.getF1())\n",
    "print(\"ConfusionMatrix:\", metrics.getConfusionMatrix())\n",
    "print(\"LabelArray:\", metrics.getLabelArray())\n",
    "print(\"LogLoss:\", metrics.getLogLoss())\n",
    "print(\"TotalSamples:\", metrics.getTotalSamples())\n",
    "print(\"ActualLabelProportion:\", metrics.getActualLabelProportion())\n",
    "print(\"ActualLabelFrequency:\", metrics.getActualLabelFrequency())\n",
    "print(\"Accuracy:\", metrics.getAccuracy())\n",
    "print(\"Kappa:\", metrics.getKappa())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
